package com.ruoyi.framework.aspectj

import com.ruoyi.common.exception.ServiceException
import com.ruoyi.common.utils.*
import com.ruoyi.common.utils.ip.IpUtils
import com.ruoyi.framework.aspectj.RateLimiterAspect
import com.ruoyi.framework.aspectj.lang.annotation.RateLimiter
import com.ruoyi.framework.aspectj.lang.enums.LimitType
import org.aspectj.lang.JoinPoint
import org.aspectj.lang.annotation.*
import org.aspectj.lang.reflect.MethodSignature
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.data.redis.core.RedisTemplate
import org.springframework.data.redis.core.script.RedisScript
import org.springframework.stereotype.Component

/**
 * Rate-limiting aspect.
 *
 * Runs before every method annotated with [RateLimiter]. It evaluates [limitScript] against Redis
 * for a key derived from the annotation and the intercepted method (see [getCombineKey]). The call
 * is rejected with a [ServiceException] when the script returns `null` or a value greater than
 * [RateLimiter.count].
 *
 * @author ruoyi
 */
@Aspect
@Component
class RateLimiterAspect {
    /** Redis client used to evaluate [limitScript]; set through [setRedisTemplate1]. */
    private var redisTemplate: RedisTemplate<Any, Any>? = null

    /** Script evaluated on every advised call; set through [setLimitScript]. */
    private var limitScript: RedisScript<Long>? = null

    /** Spring setter-injection point for the Redis template. */
    @Autowired
    fun setRedisTemplate1(redisTemplate: RedisTemplate<Any, Any>?) {
        this.redisTemplate = redisTemplate
    }

    /** Spring setter-injection point for the rate-limit script. */
    @Autowired
    fun setLimitScript(limitScript: RedisScript<Long>) {
        this.limitScript = limitScript
    }

    /**
     * Enforces the limit declared by [rateLimiter] before the advised method executes.
     *
     * The script receives the combined key as its only KEYS entry, followed by `count` and `time`
     * as arguments.
     * NOTE(review): presumably the script increments a per-key counter that expires after `time`
     * seconds and returns the current count. Confirm against the `limitScript` bean definition.
     *
     * @param point the intercepted join point; used to derive the Redis key.
     * @param rateLimiter the annotation instance on the advised method.
     * @throws ServiceException when the script returns `null` or a value greater than `count`.
     * @throws RuntimeException wrapping any other failure, such as a Redis error or a missing
     *   injected bean. The original exception is attached as the cause.
     */
    @Before("@annotation(rateLimiter)")
    @Throws(Throwable::class)
    fun doBefore(point: JoinPoint, rateLimiter: RateLimiter) {
        val time: Int = rateLimiter.time
        val count: Int = rateLimiter.count
        val combineKey = getCombineKey(rateLimiter, point)
        val keys = listOf<Any>(combineKey)
        try {
            val template = checkNotNull(redisTemplate) { "redisTemplate has not been injected" }
            val script = checkNotNull(limitScript) { "limitScript has not been injected" }
            val number: Long? = template.execute(script, keys, count, time)
            // A missing reply is treated as "over the limit" instead of letting the call through.
            if (number == null || number > count) {
                throw ServiceException("访问过于频繁，请稍候再试")
            }
            // Log the Redis key actually used, not just the annotation's key prefix.
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number, combineKey)
        } catch (e: ServiceException) {
            // The rate-limit rejection propagates unchanged.
            throw e
        } catch (e: Exception) {
            // Keep the root cause so Redis/configuration failures remain diagnosable.
            throw RuntimeException("服务器限流异常，请稍候再试", e)
        }
    }

    /**
     * Builds the Redis key for one advised call.
     *
     * The key has three parts, in order:
     * 1. `rateLimiter.key`.
     * 2. `<client ip>-`, only when the limit type is [LimitType.IP].
     * 3. `<declaring class name>-<method name>`.
     *
     * No separator is inserted between `rateLimiter.key` and the segment that follows it.
     *
     * @param rateLimiter the annotation supplying the key prefix and limit type.
     * @param point the join point; its signature is expected to be a [MethodSignature].
     * @return the combined key string.
     */
    fun getCombineKey(rateLimiter: RateLimiter, point: JoinPoint): String = buildString {
        append(rateLimiter.key)
        if (rateLimiter.limitType == LimitType.IP) {
            append(IpUtils.getIpAddr(ServletUtils.request)).append("-")
        }
        val method = (point.signature as MethodSignature).method
        append(method.declaringClass.name).append("-").append(method.name)
    }

    companion object {
        private val log = LoggerFactory.getLogger(RateLimiterAspect::class.java)
    }
}
